#!/usr/bin/env python

'''
This script produces (on stdout) a TSV file summarizing the results of a
multi-task model-test run. You want to run it on a directory containing
slurm_job, 0, 1, 2, etc.'''

from __future__ import division

import argparse
from collections import OrderedDict
import datetime
import glob
import os
import sys

import testable
import tsv_glue
import u


# Attributes to collate, listed in output column order. Each entry is a
# dotted path that main() resolves against a loaded task object with
# getattr_(); the last path component becomes the TSV column header.
# Attributes whose value is identical across all tasks are dropped from
# the output.
attrs = (
   # filled in by main() from the task directory name, not read off the task
   'taskidx',
   # leave here for testing (never resolves, so exercises the None path)
   'DOES_NOT_EXIST',

   # read directly off the task object
   'memory_use_peak',
   'time_use',

   # args.*
   'args.model',
   'args.model_parms',
   'args.trim_head',
   'args.min_instances',

   'args.fields',
   'args.unify_fields',
   'args.srid',
   'args.tokenizer',
   'args.ngram',
   'args.test_tweet_limit',
   'args.dup_users',

   'args.training',
   'args.testing',
   'args.gap',
   'args.stride',

   'args.random_seed',

   # summary.*
   'summary.test_ct',
   'summary.attempted_ct',

   'summary.train_tweet_ct',
   'summary.train_token_ct',
   'summary.test_tweet_ct',
   'summary.success_ct',
   'summary.ssuccess_ct',

   'summary.mntokens',
   'summary.dntokens',
   'summary.mncomponents',
   'summary.dncomponents',
   'summary.mnpoints',
   'summary.dnpoints',

   'summary.mcae',
   'summary.smcae',
   'summary.dcae',
   'summary.msae',
   'summary.smsae',
   'summary.dsae',

   'summary.mpra95',
   'summary.smpra95',
   'summary.dpra95',
   'summary.mpra90',
   'summary.smpra90',
   'summary.dpra90',
   'summary.mpra50',
   'summary.smpra50',
   'summary.dpra50',
   'summary.mcontour',
   'summary.smcontour',
   'summary.dcontour',

   'summary.mcovt95',
   'summary.smcovt95',
   'summary.dcovt95',
   'summary.mcovt90',
   'summary.smcovt90',
   'summary.dcovt90',
   'summary.mcovt50',
   'summary.smcovt50',
   'summary.dcovt50',
)

# Command line interface. `args` stays None until main() stores the parsed
# arguments in it.
args = None

# The module docstring doubles as the --help description; the raw-text
# formatter keeps its line breaks as written.
ap = argparse.ArgumentParser(
   formatter_class=argparse.RawTextHelpFormatter,
   description=__doc__)

# Optional positional argument; defaults to the current directory.
ap.add_argument(
   'dir',
   nargs='?',
   default='.',
   metavar='DIR',
   help='results directory to examine')


def main():
   '''Collate the per-task summaries found in args.dir and write them to
      stdout as TSV.

      Steps:

        1. Parse the command line, initialize logging, and chdir into the
           results directory.
        2. For each task returned by find_tasks(), load "<task>/summary"
           and collect every attribute listed in attrs. Tasks whose
           summary cannot be loaded (IOError) are skipped with a warning.
        3. Convert dicts and sets to space-separated strings and
           timedeltas to float days.
        4. Drop every column whose value is the same for all tasks.
           The taskidx column is always kept: it identifies the rows, and
           with fewer than two loaded tasks every column is trivially
           "constant", which previously removed taskidx as well and made
           the output stage fail with KeyError.
        5. Write a header row (last component of each attribute path)
           followed by one row per loaded task.'''
   global args
   args = u.parse_args(ap)
   u.logging_init('runsm')
   os.chdir(args.dir)

   results = OrderedDict((a, list()) for a in attrs)
   for taskidx in find_tasks():
      try:
         # NOTE(review): u.pickle_load presumably unpickles the file; only
         # point this script at result directories you trust.
         task = u.pickle_load('%s/summary' % (taskidx))
         task.summarize()
      except IOError as x:
         u.l.warning("can't load summary for task %s: %s" % (taskidx, x))
         continue
      for attr in attrs:
         if (attr == 'taskidx'):
            # special: comes from the directory name, not the task object
            results[attr].append(taskidx)
         else:
            results[attr].append(getattr_(task, attr))

   # munge a few data formats
   for values in results.values():
      for (i, value) in enumerate(values):
         if (isinstance(value, dict)):
            # stringify dicts as "key:value" pairs in sorted key order
            values[i] = ' '.join('%s:%s' % (k, v)
                                 for (k, v) in sorted(value.items()))
         elif (isinstance(value, set)):
            # stringify sets as their sorted members
            values[i] = ' '.join(str(member) for member in sorted(value))
         elif (isinstance(value, datetime.timedelta)):
            # convert timedeltas to float days
            values[i] = value.total_seconds() / 86400

   # remove attrs where everything has the same value; iterate over a copy
   # of the keys because entries are deleted along the way
   for k in list(results):
      if (k == 'taskidx'):
         continue  # row identifier; needed below even if "constant"
      v = results[k]
      if (all(v[0] == other for other in v[1:])):
         del results[k]

   tsv = tsv_glue.Writer(fp=sys.stdout)
   tsv.writerow(k.split('.')[-1] for k in results)
   for i in range(len(results['taskidx'])):
      tsv.writerow(results[a][i] for a in results)


def find_tasks():
   '''Return the names in the current directory that start with a digit,
      in a deterministic order.

      glob.glob() makes no promise about ordering, so the raw result would
      make the TSV row order vary between runs and filesystems. Names
      made up entirely of digits are ordered by numeric value (so "2"
      precedes "10"); any other match (e.g. "3x") is kept, sorted by name
      after the numeric ones.'''
   def order(name):
      if (name.isdigit()):
         # name as tie-breaker keeps e.g. "01" and "1" stably ordered
         return (0, int(name), name)
      return (1, 0, name)
   return sorted(glob.glob('[0-9]*'), key=order)

def getattr_(obj, attr):
   '''Resolve the dotted attribute path attr (e.g. "a.b.c") starting at
      obj and return the value found. If any component along the path is
      missing, i.e. the lookup raises AttributeError, return None
      instead.'''
   target = obj
   for name in attr.split('.'):
      try:
         target = getattr(target, name)
      except AttributeError:
         return None
   return target

# When executed as a script, give testable.do_script_tests() the first say;
# main() runs only if that call returns a false value.
if (__name__ == '__main__'):
   if (not testable.do_script_tests()):
      main()

# Executed on import as well as on script runs.
testable.register('')
